from environment import Environment
from agent import Agent
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import pyreadr
import seaborn as sb


def method_comparison(trials, try_beta=1):
    """Score the planning-as-inference and expected-value confidence models.

    Both models predict confidence for ``trials`` via ``Agent.pred_conf``
    (non-sampling variants). Each prediction is scored against the observed
    ``trials['confidence']`` with ``Environment.get_aic``. Both AIC values are
    printed, then returned.

    Args:
        trials: Table of trials indexable by ``'confidence'`` (the observed
            confidence ratings). It is passed unchanged to ``Agent.pred_conf``.
        try_beta: Forwarded as ``beta`` to the planning-as-inference
            prediction only; the expected-value prediction does not take it.

    Returns:
        Tuple ``(aic_pai, aic_ev)`` as produced by ``Environment.get_aic``.
    """
    observed = trials['confidence']

    # Other variants (conf_type='pai2', sampling=True) were tried here
    # previously and are intentionally not computed any more.
    predicted_pai = Agent.pred_conf(trials, beta=try_beta, conf_type='pai', sampling=False)
    predicted_ev = Agent.pred_conf(trials, conf_type='ev', sampling=False)

    aic_pai = Environment.get_aic(observed, predicted_pai)
    aic_ev = Environment.get_aic(observed, predicted_ev)

    for framework, score in (('planning as inference', aic_pai),
                             ('expected value', aic_ev)):
        print(f'aic value for {framework}: \n {score}')
    return aic_pai, aic_ev

def GraphOne(aic_pais, aic_evs, bin_width=10):
    """Plot a histogram of the per-subject AIC difference (EV minus PAI).

    Positive values mean the expected-value model has the higher AIC for
    that subject.

    Args:
        aic_pais: AIC values of the planning-as-inference model, one per
            subject (array-like).
        aic_evs: AIC values of the expected-value model; same length as
            ``aic_pais``.
        bin_width: Width of each histogram bin in AIC units. Defaults to 10,
            the value previously hard-coded here.
    """
    # asarray lets plain lists be passed as well; element-wise subtraction
    # needs arrays.
    diff_aic = np.asarray(aic_evs) - np.asarray(aic_pais)
    # Edges start at the smallest difference and run one bin past the largest,
    # so the maximum always lands inside the last bin.
    edges = np.arange(diff_aic.min(), diff_aic.max() + bin_width, bin_width)
    plt.hist(diff_aic, bins=edges)
    plt.show()

def GraphTwo(aic_pais, aic_evs):
    """Draw a grouped bar chart of each subject's AIC under both models.

    Subjects are labelled "1", "2", ... in input order. The blue bar is the
    planning-as-inference ("Decision") AIC; the gray bar is the
    expected-value ("E[x]") AIC.

    Args:
        aic_pais: Per-subject AIC values of the planning-as-inference model.
        aic_evs: Per-subject AIC values of the expected-value model.

    Raises:
        ValueError: If the two inputs differ in length.
    """
    if len(aic_pais) != len(aic_evs):
        raise ValueError("aic_pais and aic_evs must have the same length")

    rows = [
        [f"{subject}", pai, ev]
        for subject, (pai, ev) in enumerate(zip(aic_pais, aic_evs), start=1)
    ]
    _df = pd.DataFrame(rows, columns=["subject", "Decision", "E[x]"])
    _df.plot(x="subject", y=["Decision", "E[x]"], kind="bar",
             figsize=(8, 4), color=['tab:blue', 'tab:gray'])
    # One bar pair per subject: the x axis shows subjects and the y axis shows
    # raw (not averaged) AIC values. The old labels were copied from the
    # averaged chart and did not describe this plot.
    plt.xlabel("subject")
    plt.ylabel("AIC value")
    plt.show()

def GraphThree(aic_pais, aic_evs):
    """Bar chart of the mean AIC per model, with standard-error error bars.

    The standard errors of the mean (planning-as-inference first, then
    expected value) are printed before plotting.

    Args:
        aic_pais: Per-subject AIC values of the planning-as-inference model.
        aic_evs: Per-subject AIC values of the expected-value model.
    """
    _df = pd.DataFrame([["", np.mean(aic_pais), np.mean(aic_evs)]],
                       columns=["subject", "Decision", "E[x]"])

    # The standard error of the mean is s / sqrt(n), where s is the *sample*
    # standard deviation (ddof=1, Bessel's correction). np.std defaults to
    # ddof=0, which biases the SEM low, most noticeably with few subjects.
    # With a single subject this yields nan, since the SEM is undefined there.
    y_err1 = np.std(aic_pais, ddof=1) / np.sqrt(len(aic_pais))
    y_err2 = np.std(aic_evs, ddof=1) / np.sqrt(len(aic_evs))
    print(y_err1)
    print(y_err2)

    # One error value per plotted column, in the same order as ``y``.
    _df.plot(x="subject", y=["Decision", "E[x]"], kind="bar",
             figsize=(8, 4), color=['tab:blue', 'tab:gray'], yerr=[[y_err1], [y_err2]])
    plt.xlabel("confidence frameworks")
    plt.ylabel("average AIC value")
    plt.show()

def GraphFour(aic_pais, aic_evs):
    """Plot each subject's AIC for both models against subject number.

    After showing the figure, both input arrays are printed.

    Args:
        aic_pais: Per-subject AIC values of the planning-as-inference model.
        aic_evs: Per-subject AIC values of the expected-value model; same
            length as ``aic_pais``.
    """
    # Subject numbers 1..N come from the input itself. This previously read
    # the module-level ``num_subs``, which exists only when the file runs as a
    # script, so calling GraphFour after an import raised NameError.
    x = np.arange(1, len(aic_pais) + 1)
    plt.plot(x, aic_pais, '.', label='pai')
    plt.plot(x, aic_evs, '.', label='ev')
    plt.legend()
    plt.show()

    print("aic values for pai")
    print(aic_pais)
    print("aic values for expected value")
    print(aic_evs)

def GraphFive(aic_pais, aic_evs):
    """Violin plot of the per-subject AIC difference (EV minus PAI).

    Every subject has the same category value (``test_temp`` == 1), so a
    single violin is drawn. As a side effect, seaborn's global style is set
    to 'whitegrid'.
    """
    delta = aic_evs - aic_pais
    n_subjects = len(delta)

    frame = pd.DataFrame({
        "??": [str(k + 1) for k in range(n_subjects)],
        "Difference": [delta[k] for k in range(n_subjects)],
        # Constant category; a sign-based split (+1 / -1) was tried earlier.
        "test_temp": [1] * n_subjects,
    })

    sb.set(style='whitegrid')
    sb.violinplot(x='test_temp', y='Difference', data=frame)
    plt.show()

def GraphSix(aic_pais, aic_evs):
    """Strip plot of the per-subject AIC difference (EV minus PAI).

    All points share x == 1 in a narrow 2x6-inch figure. A short dashed black
    segment marks a difference of -2. As a side effect, the global matplotlib
    font is set to size 20 via ``plt.rc``.

    Args:
        aic_pais: Per-subject AIC values of the planning-as-inference model
            (array-like).
        aic_evs: Per-subject AIC values of the expected-value model; same
            length as ``aic_pais``.
    """
    diff_aics = np.asarray(aic_evs) - np.asarray(aic_pais)
    x = np.ones(len(diff_aics))

    font = {'weight': 'normal', 'size': 20}
    plt.rc('font', **font)
    axes = plt.gca()
    figure = plt.gcf()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.tick_params(top=False, right=False)
    # Tick width is documented as a float in points; pass a number, not '4'.
    axes.tick_params(direction='out', width=4)
    figure.set_size_inches(2, 6)

    plt.plot(x, diff_aics, '.')
    # Dashed reference segment at delta AIC = -2, just spanning the point column.
    plt.plot([0.96, 1.04], [-2, -2], 'k', linestyle='dashed')
    plt.yticks([-25, 0, 100, 200, 300, 400, 500, 600])
    plt.xticks([])
    plt.xlabel('')
    plt.ylabel(r'$\Delta$ AIC values')
    # tight_layout must run after every artist and label exists. Otherwise it
    # sizes margins for an empty axes, and the y label can be clipped in this
    # 2-inch-wide figure.
    plt.tight_layout()
    plt.show()

def GraphSeven(aic_pais, aic_evs):
    """Side-by-side box plots of the AIC distribution for each model.

    Boxes are filled blue (decision / planning-as-inference) and orange
    (expected value), with black median lines.
    """
    tick_names = ['decision', 'expected value']
    fills = ['tab:blue', 'tab:orange']

    _, ax = plt.subplots()
    ax.set_ylabel('aic')

    # Tick labels are applied with plt.xticks rather than boxplot's
    # ``tick_labels`` keyword (presumably for older matplotlib — confirm).
    box_parts = ax.boxplot([aic_pais, aic_evs], patch_artist=True)
    plt.xticks([1, 2], tick_names)

    for box, fill in zip(box_parts['boxes'], fills):
        box.set_facecolor(fill)
    for median_line in box_parts['medians']:
        median_line.set_color('black')

    plt.show()

def GraphEight(aic_pais, aic_evs, limit=810):
    """Scatter each subject's EV AIC against their PAI AIC, with a y = x line.

    Points above the dashed red identity line have a higher AIC under the
    expected-value model. As a side effect, the global matplotlib font is set
    to size 20 via ``plt.rc``.

    Args:
        aic_pais: Per-subject AIC values of the planning-as-inference model
            (x axis).
        aic_evs: Per-subject AIC values of the expected-value model (y axis).
        limit: Upper bound of both axes; the lower bound is fixed at 0.
            Defaults to 810, the value previously hard-coded. Pass ``None`` to
            fit the data: the bound becomes 5% above the largest AIC in either
            input.
    """
    if limit is None:
        limit = 1.05 * max(np.max(aic_pais), np.max(aic_evs))

    font = {'weight': 'normal', 'size': 20}
    plt.rc('font', **font)
    plt.plot(aic_pais, aic_evs, 'o')
    # Draw y = x explicitly across the full axis range. The old version
    # plotted np.arange(0, 810) against its own index and stopped at 809.
    plt.plot([0, limit], [0, limit], 'r--')
    plt.xlabel('AIC for soft optimality')
    plt.ylabel('AIC for expected value ratio')
    plt.xlim((0, limit))
    plt.ylim((0, limit))
    plt.show()



if __name__ == '__main__':
    data = pyreadr.read_r('Data.RData')
    data = data['Data']

    print("old results for all subjects")
    method_comparison(data)

    # Iterate over the subjects actually present; groupby yields keys in
    # ascending order. The old code assumed subjects were numbered 1..N with
    # N in the last row, so a gap or unsorted data raised KeyError or
    # silently skipped subjects.
    grouped_data = data.groupby('subNr')
    # Kept as a module-level name because other functions in this file may
    # reference it.
    num_subs = grouped_data.ngroups
    aic_pais = np.zeros(num_subs)
    aic_evs = np.zeros(num_subs)
    subject_ids = []
    print(num_subs)
    for idx, (sub_id, sub_data) in enumerate(grouped_data):
        sub_num = int(sub_id)
        subject_ids.append(sub_num)
        print(f'old results for subject {sub_num}')
        aic_pais[idx], aic_evs[idx] = method_comparison(sub_data)

    # Uncomment the figure(s) to produce.
    # GraphOne(aic_pais, aic_evs)
    # GraphTwo(aic_pais, aic_evs)
    # GraphThree(aic_pais, aic_evs)
    # GraphFour(aic_pais, aic_evs)
    # GraphFive(aic_pais, aic_evs)
    # GraphSix(aic_pais, aic_evs)
    # GraphSeven(aic_pais, aic_evs)
    GraphEight(aic_pais, aic_evs)

    print(aic_pais)
    print(aic_evs)

    aic_diff = aic_evs - aic_pais
    print(aic_diff)

    # True means the expected-value AIC exceeds the planning-as-inference AIC
    # for that subject.
    for sub_num, diff in zip(subject_ids, aic_diff):
        print(sub_num, 0 < diff)



